Skip to content

fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504)#28

Closed
aviruthen wants to merge 3 commits intomasterfrom
fix/bug-pipeline-parameters-parameterinteger-5504
Closed

fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504)#28
aviruthen wants to merge 3 commits intomasterfrom
fix/bug-pipeline-parameters-parameterinteger-5504

Conversation

@aviruthen
Copy link
Copy Markdown
Owner

Description

PipelineVariable Support in ModelTrainer Fields (GH#5524)

This PR ensures that ModelTrainer fields that accept StrPipeVar (Union of str and PipelineVariable) work correctly when PipelineVariable objects (e.g., ParameterString) are passed.

Changes

  • sagemaker-train/src/sagemaker/train/utils.py: Updated _get_repo_name_from_image to handle PipelineVariable objects gracefully by returning a default name instead of attempting string operations on non-string types.

  • sagemaker-train/src/sagemaker/train/model_trainer.py: Updated _validate_training_image_and_algorithm_name to properly detect PipelineVariable instances as truthy values during validation, since PipelineVariable objects may not support standard boolean coercion.

Testing

Verified with unit tests in test_model_trainer_pipeline_variable.py that:

  • training_image, algorithm_name, training_input_mode accept ParameterString
  • environment dict values accept ParameterString
  • Plain string values continue to work (regression tests)
  • Invalid types (e.g., int) are still rejected

Related Issue

Related issue: 5504

Changes Made

No response from agent

AI-Generated PR

This PR was automatically generated by the PySDK Issue Agent.

  • Confidence score: 0%
  • Classification: bug
  • SDK version target: V3

Merge Checklist

  • Changes are backward compatible
  • Commit message follows prefix: description format
  • Unit tests added/updated
  • Integration tests added (if applicable)
  • Documentation updated (if applicable)

Copy link
Copy Markdown
Owner Author

@aviruthen aviruthen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖 AI Code Review

This PR fixes PipelineVariable support in ModelTrainer but has several issues: removed type annotations without replacement, a magic string constant, and the test file referenced in the description is not included in the diff. The approach is reasonable but needs refinement.


def _validate_training_image_and_algorithm_name(
self, training_image: Optional[str], algorithm_name: Optional[str]
self, training_image, algorithm_name
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing type annotations: The type annotations for training_image and algorithm_name were removed entirely. Per SDK coding standards (PEP 484), all public/private methods must retain type annotations. Since these parameters now accept both str and PipelineVariable, please use the appropriate union type:

def _validate_training_image_and_algorithm_name(
    self, training_image: str | PipelineVariable | None, algorithm_name: str | PipelineVariable | None
):

Or if StrPipeVar is already defined as a type alias in the codebase, use that.

Comment thread sagemaker-train/src/sagemaker/train/model_trainer.py

Args:
image (str): The image URI
image: The image URI (str or PipelineVariable)
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed type annotation: Same issue here — the image parameter lost its type annotation. Please restore it with the correct union type:

def _get_repo_name_from_image(image: str | PipelineVariable) -> str:

Comment thread sagemaker-train/src/sagemaker/train/utils.py
Comment thread sagemaker-train/src/sagemaker/train/utils.py

def _validate_training_image_and_algorithm_name(
self, training_image: Optional[str], algorithm_name: Optional[str]
self, training_image, algorithm_name
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing test file in the diff: The PR description references test_model_trainer_pipeline_variable.py but this file is not included in the changed files. Please ensure the test file is included in the PR. Without tests, we cannot verify the fix works or guard against regressions.

Copy link
Copy Markdown
Owner Author

@aviruthen aviruthen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖 AI Code Review

This PR fixes PipelineVariable support in ModelTrainer by handling non-string PipelineVariable objects in validation and utility functions. The approach is reasonable, but there are some issues: a line exceeds 100 characters, test fixtures have significant duplication that should be extracted, and the from __future__ import annotations import is missing in favor of the older from __future__ import absolute_import.

Comment thread sagemaker-train/src/sagemaker/train/model_trainer.py
Comment thread sagemaker-train/src/sagemaker/train/model_trainer.py
Comment thread sagemaker-train/src/sagemaker/train/model_trainer.py
mock_session = MagicMock()
mock_session.boto_region_name = "us-east-1"
mock_session.default_bucket.return_value = "my-bucket"
mock_session.default_bucket_prefix = None
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The mock session and mock_defaults setup is duplicated across 4 test methods (test_training_image_accepts_parameter_string, test_algorithm_name_accepts_parameter_string, test_environment_values_accept_parameter_string, test_plain_string_values_still_work). Extract this into a @pytest.fixture to reduce duplication and improve maintainability:

@pytest.fixture
def mock_session():
    session = MagicMock()
    session.boto_region_name = "us-east-1"
    session.default_bucket.return_value = "my-bucket"
    session.default_bucket_prefix = None
    return session

And similarly for the mock_defaults patching.

Comment thread sagemaker-train/tst/unit/sagemaker/train/test_model_trainer_pipeline_variable.py Outdated
Comment thread sagemaker-train/src/sagemaker/train/utils.py
mock_defaults.get_role.return_value = "arn:aws:iam::123456789012:role/SageMakerRole"
mock_defaults.get_base_job_name.return_value = "test-job"
mock_defaults.get_compute.return_value = Compute(
instance_type="ml.m5.xlarge", instance_count=1
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding a test case for _validate_training_image_and_algorithm_name where one argument is a PipelineVariable and the other is None — this is the primary success case the fix enables. The current tests test_training_image_accepts_parameter_string and test_algorithm_name_accepts_parameter_string test this indirectly through full ModelTrainer construction, but a direct unit test of the validation method (like the rejection tests at lines 199-238) would be more focused and faster.

Copy link
Copy Markdown
Owner Author

@aviruthen aviruthen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖 AI Code Review

This PR fixes PipelineVariable support in ModelTrainer by handling PipelineVariable objects that don't support standard boolean coercion. The approach is reasonable, but there are a few issues: a duplicate import in utils.py, the validation logic could be simplified, and the test file has imports inside test methods rather than at module level.

from datetime import datetime
from typing import Literal, Any

from typing import Union
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicate import: Union is imported here from typing, but there's already a from typing import Literal, Any on line 24. Consolidate into a single import statement:

from typing import Literal, Any, Union

Also, since the module already imports PipelineVariable from sagemaker.core.workflow.parameters on line 30, and from __future__ import annotations is not present, consider adding it to enable PEP 604 union syntax (str | PipelineVariable) per SDK conventions.

# PipelineVariable objects do not support standard boolean coercion
# (__bool__ raises TypeError), so we use isinstance checks to detect
# them as truthy values during validation.
has_image = isinstance(training_image, PipelineVariable) or bool(training_image)
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic isinstance(training_image, PipelineVariable) or bool(training_image) will raise TypeError if training_image is a PipelineVariable that doesn't support __bool__ — but since isinstance short-circuits via or, this is actually safe. However, consider simplifying to:

has_image = training_image is not None and training_image != ""
has_algo = algorithm_name is not None and algorithm_name != ""

This avoids calling bool() entirely and is more explicit about what "not provided" means (None or empty string). The is not None check naturally handles PipelineVariable objects correctly.



def _get_repo_name_from_image(image: str) -> str:
def _get_repo_name_from_image(image: Union[str, PipelineVariable]) -> str:
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return type annotation says -> str but when a PipelineVariable is passed, it returns a string placeholder, so the annotation is technically correct. However, consider documenting in the docstring that the placeholder _PIPELINE_VARIABLE_IMAGE_PLACEHOLDER is returned for PipelineVariable inputs, so downstream callers understand the behavior.

from sagemaker.train.model_trainer import ModelTrainer
from sagemaker.train.configs import Compute

param = ParameterString(
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The from sagemaker.train.model_trainer import ModelTrainer import is repeated inside every test method in this class. Move it to the top of the file with the other imports. Inline imports in tests add unnecessary noise and are not consistent with SDK test conventions.

_TEST_IMAGE_URI = (
"683313688378.dkr.ecr.us-east-1.amazonaws.com/"
"sagemaker-xgboost:1.0-1-cpu-py3"
)
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: This test image URI contains a hardcoded region (us-east-1) and account ID (683313688378). While this is acceptable for unit tests since it's just a string constant and not used to make actual API calls, consider using a clearly fake account ID (e.g., 123456789012) for consistency with the mock session fixture below.

assert trainer.training_image == _TEST_IMAGE_URI

def test_validation_accepts_pipeline_variable_image_none_algo(self):
"""Test validation accepts PipelineVariable image with None algorithm."""
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using ModelTrainer.__new__(ModelTrainer) to bypass __init__ and directly test the validation method is fragile — it creates an uninitialized object. If _validate_training_image_and_algorithm_name ever accesses self attributes, these tests will break with confusing errors. Consider either:

  1. Making _validate_training_image_and_algorithm_name a @staticmethod (it doesn't use self), or
  2. Using the existing mock_train_defaults fixture to construct a proper instance and test through the public interface.


class TestSafeSerializeWithPipelineVariable:
"""Tests for safe_serialize handling of PipelineVariable objects."""

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TestSafeSerializeWithPipelineVariable tests verify safe_serialize behavior with PipelineVariable, but the PR diff doesn't show any changes to safe_serialize. If safe_serialize already handled PipelineVariable correctly, these tests are documenting existing behavior (which is fine), but it would be good to note that in the test class docstring. If safe_serialize needed changes, those changes should be included in this PR.

@aviruthen aviruthen closed this Mar 26, 2026
@aviruthen aviruthen deleted the fix/bug-pipeline-parameters-parameterinteger-5504 branch March 26, 2026 23:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant